{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Calibration\n",
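    "This tutorial shows how to calibrate the scores predicted by a knowledge graph embedding model, so that they are bounded and can be read as probabilities."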
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "# Benchmark datasets are available in the ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "# load fb15k-237 dataset\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e03ab8",
   "metadata": {},
   "source": [
    "## Train and predict scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n",
      "Epoch 1/10\n",
      "29/29 [==============================] - 3s 103ms/step - loss: 6736.2100\n",
      "Epoch 2/10\n",
      "29/29 [==============================] - 1s 37ms/step - loss: 6734.2466\n",
      "Epoch 3/10\n",
      "29/29 [==============================] - 1s 33ms/step - loss: 6722.6729\n",
      "Epoch 4/10\n",
      "29/29 [==============================] - 1s 35ms/step - loss: 6671.5366\n",
      "Epoch 5/10\n",
      "29/29 [==============================] - 1s 32ms/step - loss: 6529.1528\n",
      "Epoch 6/10\n",
      "29/29 [==============================] - 1s 32ms/step - loss: 6261.3403\n",
      "Epoch 7/10\n",
      "29/29 [==============================] - 1s 35ms/step - loss: 5892.9292\n",
      "Epoch 8/10\n",
      "29/29 [==============================] - 1s 32ms/step - loss: 5488.4858\n",
      "Epoch 9/10\n",
      "29/29 [==============================] - 1s 31ms/step - loss: 5097.4917\n",
      "Epoch 10/10\n",
      "29/29 [==============================] - 1s 32ms/step - loss: 4741.0933\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x28a904610>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the knowledge graph embedding (KGE) model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# Create the model with the ComplEx scoring function.\n",
    "# eta is the number of negatives generated per positive triple,\n",
    "# k is the embedding size.\n",
    "model = ScoringBasedEmbeddingModel(eta=1,\n",
    "                                   k=100,\n",
    "                                   scoring_type='ComplEx')\n",
    "\n",
    "# Compile the model with the loss and the optimizer\n",
    "model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "# Train on the training split of the `dataset` loaded above\n",
    "# (FB15k-237 is already in memory, so it is not loaded a second time)\n",
    "model.fit(dataset['train'],\n",
    "          batch_size=10000,\n",
    "          epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "24cb1627",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-2.8158133 -1.406638  -1.2758102 ...  6.1383247  6.174588   6.233885 ]\n",
      "[ 3834  4066 18634 ...  4972  2802 15757]\n"
     ]
    }
   ],
   "source": [
    "# The predicted scores are unbounded, so a single score on its own\n",
    "# does not tell us whether it is a good or a bad score.\n",
    "pred_out = model.predict(dataset['test'], batch_size=10000)\n",
    "\n",
    "# print the scores sorted in ascending order\n",
    "print(np.sort(pred_out))\n",
    "# print the indices of the test triples ordered from lowest to highest score\n",
    "print(np.argsort(pred_out))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f22b45e",
   "metadata": {},
   "source": [
    "## Model calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fd478620",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:47<00:00,  2.13it/s]\n"
     ]
    }
   ],
   "source": [
    "# calibrate on the test set\n",
    "model.calibrate(dataset['test'],        # Dataset to calibrate on\n",
    "                batch_size=500,         # Batch size to be used for calibration\n",
    "                positive_base_rate=0.8, # prior which indicates what percentage of the dataset might be correct\n",
    "                epochs=100,             # Number of epochs\n",
    "                verbose=True)             "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f625f2a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.00527451 0.05239511 0.06431627 ... 0.9999361  0.9999398  0.9999454 ]\n",
      "[ 3834  4066 18634 ...  4972  2802 15757]\n"
     ]
    }
   ],
   "source": [
    "# Use predict_proba to obtain the calibrated scores.\n",
    "# The predicted scores are now bounded and lie in the interval [0, 1].\n",
    "out = model.predict_proba(dataset['test'], batch_size=10000)\n",
    "\n",
    "# The score values have changed, but the ordering of the triples (argsort) is the same as before calibration\n",
    "print(np.sort(out))\n",
    "print(np.argsort(out))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
